/*
 * Copyright 2013-2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.flying.fish.gateway.filter;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import com.flying.fish.gateway.event.WeightRemoveApplicationEvent;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.event.PredicateArgsEvent;
import org.springframework.cloud.gateway.event.RefreshRoutesEvent;
import org.springframework.cloud.gateway.event.WeightDefinedEvent;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.support.ConfigurationService;
import org.springframework.cloud.gateway.support.WeightConfig;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.Ordered;
import org.springframework.core.style.ToStringCreator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.WEIGHT_ATTR;

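/**
 * This filter is a direct copy of the gateway's {@code WeightCalculatorWebFilter} source
 * code, extended with a listener for {@code WeightRemoveApplicationEvent} that cleans up
 * the weight data of a load-balancing group.
 *
 * <p>It addresses a problem in the original gateway filter: when several gateway routes
 * share the same weight group and one of them is removed, the removed route's weight is
 * not cleared from the group. The random-number based selection can then still pick the
 * route that has gone offline, which ends in a 404 response.
 *
 * <p>Customization by JL, 2021/10/12, version V1.0.
 *
 * @author Spencer Gibb
 * @author Alexey Nakidkin
 * @author JL
 */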
@Slf4j
public class CustomWeightCalculatorWebFilter implements WebFilter, Ordered, SmartApplicationListener {

    /**
     * Order of Weight Calculator Web filter.
     */
    public static final int WEIGHT_CALC_FILTER_ORDER = 10001;

    // Legacy commons-logging declaration, kept for reference; the class-level @Slf4j
    // annotation now provides the `log` field used throughout this class.
//    private static final Log log = LogFactory.getLog(CustomWeightCalculatorWebFilter.class);

    // Explicitly declared SLF4J logger for this class, in addition to the `log` field.
    private static final Logger logger = LoggerFactory.getLogger(CustomWeightCalculatorWebFilter.class);

    // Lazily resolved route locator; used on RefreshRoutesEvent to force route loading.
    private final ObjectProvider<RouteLocator> routeLocator;

    // Binds predicate arguments onto a WeightConfig in handle(PredicateArgsEvent).
    private final ConfigurationService configurationService;

    // Source of the per-request random number; replaceable through setRandom.
    private Random random = new Random();

    // Value returned by getOrder(); defaults to WEIGHT_CALC_FILTER_ORDER.
    private int order = WEIGHT_CALC_FILTER_ORDER;

    // Weight configuration per group name. Entries are replaced as a whole by
    // addWeightConfig and removed on WeightRemoveApplicationEvent.
    private Map<String, GroupWeightConfig> groupWeights = new ConcurrentHashMap<>();

    // Flips to true on the first RefreshRoutesEvent so only that one blocks on route loading.
    private final AtomicBoolean routeLocatorInitialized = new AtomicBoolean();

    /**
     * Creates the filter with the collaborators it needs.
     *
     * @param routeLocator provider of the route locator that is queried when routes are
     * refreshed
     * @param configurationService service used to bind predicate arguments onto a
     * {@code WeightConfig}
     */
    public CustomWeightCalculatorWebFilter(ObjectProvider<RouteLocator> routeLocator,
                                     ConfigurationService configurationService) {
        this.configurationService = configurationService;
        this.routeLocator = routeLocator;
    }

    /* for testing */
    /**
     * Returns the group-to-route map stored on the exchange under {@code WEIGHT_ATTR},
     * creating and registering an empty one when the attribute is not set yet.
     *
     * @param exchange the current server exchange
     * @return the map held in the exchange's {@code WEIGHT_ATTR} attribute, never null
     */
    static Map<String, String> getWeights(ServerWebExchange exchange) {
        Map<String, String> existing = exchange.getAttribute(WEIGHT_ATTR);
        if (existing != null) {
            return existing;
        }

        // First access for this exchange: register a fresh map so later callers share it.
        Map<String, String> created = new ConcurrentHashMap<>();
        exchange.getAttributes().put(WEIGHT_ATTR, created);
        return created;
    }

    /**
     * Returns the order value configured for this filter.
     *
     * @return the current order; {@link #WEIGHT_CALC_FILTER_ORDER} unless changed through
     * {@link #setOrder(int)}
     */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Overrides the order value reported by {@link #getOrder()}.
     *
     * @param order the new order value
     */
    public void setOrder(int order) {
        this.order = order;
    }

    /**
     * Replaces the random number source that the filter draws from when choosing a route
     * for each weight group.
     *
     * @param random the random source to use from now on
     */
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Tells the event multicaster which event types this listener handles.
     *
     * @param eventType the type of the event being published
     * @return {@code true} for predicate-args, weight-defined, refresh-routes and
     * weight-remove events (or subtypes of them), {@code false} otherwise
     */
    @Override
    public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
        // from config file
        if (PredicateArgsEvent.class.isAssignableFrom(eventType)) {
            return true;
        }
        // from java dsl
        if (WeightDefinedEvent.class.isAssignableFrom(eventType)) {
            return true;
        }
        // force initialization
        if (RefreshRoutesEvent.class.isAssignableFrom(eventType)) {
            return true;
        }
        // remove weight group
        return WeightRemoveApplicationEvent.class.isAssignableFrom(eventType);
    }

    /**
     * Accepts events from any source; filtering is done by event type only.
     *
     * @param sourceType the type of the object that published the event (ignored)
     * @return always {@code true}
     */
    @Override
    public boolean supportsSourceType(Class<?> sourceType) {
        return true;
    }

    /**
     * Dispatches a supported application event to the matching action: binding and
     * registering a weight from predicate arguments, registering an already built weight
     * config, forcing route loading on refresh, or dropping a whole weight group.
     *
     * @param event the published event; unsupported types are ignored
     */
    @Override
    public void onApplicationEvent(ApplicationEvent event) {
        if (event instanceof PredicateArgsEvent) {
            handle((PredicateArgsEvent) event);
            return;
        }

        if (event instanceof WeightDefinedEvent) {
            WeightDefinedEvent weightDefined = (WeightDefinedEvent) event;
            addWeightConfig(weightDefined.getWeightConfig());
            return;
        }

        if (event instanceof RefreshRoutesEvent && routeLocator != null) {
            // forces initialization
            boolean firstRefresh = routeLocatorInitialized.compareAndSet(false, true);
            if (firstRefresh) {
                // on first time, block so that app fails to start if there are errors in
                // routes
                // see gh-1574
                routeLocator.ifAvailable(locator -> locator.getRoutes().blockLast());
            }
            else {
                // this preserves previous behaviour on refresh, this could likely go away
                routeLocator.ifAvailable(locator -> locator.getRoutes().subscribe());
            }
            return;
        }

        if (event instanceof WeightRemoveApplicationEvent) {
            // remove group
            String removedGroup = ((WeightRemoveApplicationEvent) event).getGroup();
            groupWeights.remove(removedGroup);
        }
    }

    /**
     * Builds a {@code WeightConfig} for the event's route from its predicate arguments and
     * registers it. Events without arguments, or without any weight-prefixed key, are
     * ignored.
     *
     * @param event the predicate-args event carrying the route id and raw arguments
     */
    public void handle(PredicateArgsEvent event) {
        Map<String, Object> args = event.getArgs();

        boolean carriesWeightArgs = !args.isEmpty() && hasRelevantKey(args);
        if (carriesWeightArgs) {
            WeightConfig config = new WeightConfig(event.getRouteId());

            // Bind the normalized arguments onto the freshly created config object.
            this.configurationService.with(config)
                    .name(WeightConfig.CONFIG_PREFIX)
                    .normalizedProperties(args)
                    .bind();

            addWeightConfig(config);
        }
    }

    /**
     * Checks whether at least one argument key starts with the weight config prefix
     * followed by a dot.
     *
     * @param args the predicate arguments to inspect
     * @return {@code true} if a weight-prefixed key is present
     */
    private boolean hasRelevantKey(Map<String, Object> args) {
        final String prefix = WeightConfig.CONFIG_PREFIX + ".";
        for (String key : args.keySet()) {
            if (key.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds (or updates) one route's weight in its group and recomputes the group's
     * normalized weights, range indexes and cumulative ranges.
     *
     * <p>A new {@code GroupWeightConfig} is always built and only published into
     * {@code groupWeights} once every calculation is done, so {@code filter} never sees a
     * half-updated configuration.
     *
     * <p>NOTE: when the weights of a group sum to 0, the normalization divides by zero and
     * the resulting ranges are non-finite (NaN/Infinity).
     *
     * @param weightConfig the group, route id and weight to register
     */
    /* for testing */ void addWeightConfig(WeightConfig weightConfig) {
        String group = weightConfig.getGroup();

        // Single lookup instead of containsKey + get: the group can be removed by a
        // WeightRemoveApplicationEvent between the two calls, which would hand null to
        // the copy constructor and fail with a NullPointerException.
        GroupWeightConfig existing = groupWeights.get(group);

        // only create new GroupWeightConfig rather than modify
        // and put at end of calculations. This avoids concurrency problems
        // later during filter execution.
        GroupWeightConfig config = (existing != null) ? new GroupWeightConfig(existing)
                : new GroupWeightConfig(group);

        config.weights.put(weightConfig.getRouteId(), weightConfig.getWeight());

        // recalculate

        // normalize weights
        int weightsSum = 0;
        for (Integer weight : config.weights.values()) {
            weightsSum += weight;
        }

        int index = 0;
        for (Map.Entry<String, Integer> entry : config.weights.entrySet()) {
            String routeId = entry.getKey();
            Integer weight = entry.getValue();
            Double normalizedWeight = weight / (double) weightsSum;
            config.normalizedWeights.put(routeId, normalizedWeight);

            // recalculate rangeIndexes: position in the ranges list -> route id
            config.rangeIndexes.put(index++, routeId);
        }

        // Rebuild the cumulative ranges: [0, w0, w0 + w1, ...]. Route i owns the
        // half-open interval [ranges[i], ranges[i + 1]).
        config.ranges.clear();
        config.ranges.add(0.0);

        double upperBound = 0.0;
        for (Double normalizedWeight : config.normalizedWeights.values()) {
            upperBound += normalizedWeight;
            config.ranges.add(upperBound);
        }

        log.trace("Recalculated group weight config {}", config);

        // only update after all calculations
        groupWeights.put(group, config);
    }

    /**
     * Exposes the live map of group name to weight configuration.
     *
     * @return the internal map itself, not a copy
     */
    /* for testing */ Map<String, GroupWeightConfig> getGroupWeights() {
        return this.groupWeights;
    }

    /**
     * For every registered weight group, draws a random number in {@code [0, 1)}, finds
     * the range it falls into and records the owning route id for that group in the
     * exchange's weight attribute. The filter chain is then continued.
     *
     * <p>If the random number matches no range of a group, nothing is recorded for that
     * group.
     *
     * @param exchange the current server exchange; its weight attribute map is updated
     * @param chain the remaining filter chain
     * @return the result of delegating to {@code chain}
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        // Runs for every request: log the invocation once, at debug level, instead of
        // twice at info level through two loggers of the same class.
        log.debug("web filter invoke!");

        Map<String, String> weights = getWeights(exchange);

        // Iterate entries directly so each group needs only one map access.
        for (Map.Entry<String, GroupWeightConfig> entry : groupWeights.entrySet()) {
            String group = entry.getKey();
            GroupWeightConfig config = entry.getValue();

            if (config == null) {
                log.debug("No GroupWeightConfig found for group: {}", group);
                continue; // nothing we can do, but this is odd
            }

            double r = this.random.nextDouble();

            List<Double> ranges = config.ranges;

            // Guard kept: the three-argument call would otherwise box and allocate per group.
            if (log.isTraceEnabled()) {
                log.trace("Weight for group: {}, ranges: {}, r: {}", group, ranges, r);
            }

            // Route i owns the half-open interval [ranges[i], ranges[i + 1]).
            for (int i = 0; i < ranges.size() - 1; i++) {
                if (r >= ranges.get(i) && r < ranges.get(i + 1)) {
                    String routeId = config.rangeIndexes.get(i);
                    weights.put(group, routeId);
                    break;
                }
            }
        }

        log.trace("Weights attr: {}", weights);

        return chain.filter(exchange);
    }

    /**
     * Weight data of one group: the raw weights per route, their normalized values, the
     * cumulative ranges derived from them and the mapping from range position to route id.
     */
    /* for testing */ static class GroupWeightConfig {

        /** Name of the weight group. */
        String group;

        /** Raw weight per route id, in insertion order. */
        LinkedHashMap<String, Integer> weights;

        /** Weight per route id divided by the sum of all weights in the group. */
        LinkedHashMap<String, Double> normalizedWeights;

        /** Position in {@link #ranges} mapped to the route id owning that range. */
        LinkedHashMap<Integer, String> rangeIndexes;

        /** Cumulative range boundaries; rebuilt whenever a weight is added. */
        List<Double> ranges;

        /**
         * Creates an empty configuration for the given group.
         *
         * @param group the group name
         */
        GroupWeightConfig(String group) {
            this.group = group;
            this.weights = new LinkedHashMap<>();
            this.normalizedWeights = new LinkedHashMap<>();
            this.rangeIndexes = new LinkedHashMap<>();
            this.ranges = new ArrayList<>();
        }

        /**
         * Creates a copy of another configuration. The maps are copied; the ranges list
         * starts empty and is not taken over from {@code other}.
         *
         * @param other the configuration to copy from
         */
        GroupWeightConfig(GroupWeightConfig other) {
            this.group = other.group;
            this.weights = new LinkedHashMap<>(other.weights);
            this.normalizedWeights = new LinkedHashMap<>(other.normalizedWeights);
            this.rangeIndexes = new LinkedHashMap<>(other.rangeIndexes);
            this.ranges = new ArrayList<>();
        }

        /**
         * Renders group, weights, normalized weights and range indexes; the ranges list
         * is not part of the output.
         */
        @Override
        public String toString() {
            ToStringCreator creator = new ToStringCreator(this);
            creator.append("group", group);
            creator.append("weights", weights);
            creator.append("normalizedWeights", normalizedWeights);
            creator.append("rangeIndexes", rangeIndexes);
            return creator.toString();
        }

    }

}
